{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from pytorch_tabular.utils import make_mixed_dataset, print_metrics\n",
    "\n",
    "# %load_ext autoreload\n",
    "# %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "data, cat_col_names, num_col_names = make_mixed_dataset(\n",
    "    task=\"regression\", n_samples=10000, n_features=20, n_categories=4\n",
    ")\n",
    "train, test = train_test_split(data, random_state=42)\n",
    "train, val = train_test_split(train, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "true"
   },
   "source": [
    "# Importing the Library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from pytorch_tabular import TabularModel\n",
    "from pytorch_tabular.models import CategoryEmbeddingModelConfig\n",
    "from pytorch_tabular.config import (\n",
    "    DataConfig,\n",
    "    OptimizerConfig,\n",
    "    TrainerConfig,\n",
    "    ExperimentConfig,\n",
    "    ModelConfig,\n",
    ")\n",
    "from pytorch_tabular.models import BaseModel\n",
    "from pytorch_tabular.models.common.layers import Embedding1dLayer\n",
    "from pytorch_tabular.models.common.heads import LinearHeadConfig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Defining a Custom Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from omegaconf import DictConfig\n",
    "from typing import Dict\n",
    "from dataclasses import dataclass, field"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "PyTorch Tabular is very easy to extend and infinitely customizable. All the models that have been implemented in PyTorch Tabular inherits an Abstract Class `BaseModel` which is in fact a PyTorchLightning Model.\n",
    "\n",
    "It handles all the major functions like decoding the config params and setting up the loss and metrics. It also calculates the Loss and metrics and feeds it back to the PyTorch Lightning Trainer which does the back-propagation.\n",
    "\n",
    "If we look at the anatomy of a PyTorch Tabular model, there are three main components:\n",
    "\n",
    "1. **Embedding Layer** \n",
    "2. **Backbone**\n",
    "3. **Head**\n",
    "\n",
    "**Embedding Layer** takes the input from the dataloader, which is a dictionary with `categorical` and `continuous` tensors under those keys. The Embedding Layer converts this dictionary to a single tensor, handling the categorical tensors the right way. There are two Embedding Layers already implemented, EmbeddingLayer1d, EmbeddingLayer2d.\n",
    "\n",
    "**Backbone** is the main model architecture.\n",
    "\n",
    "**Head** is the linear layers which takes the output of the backbone and converts it into the output we desire.\n",
    "\n",
    "To encapsulate and enforce this structure, `BaseModel` requires us to define three `property` methods:\n",
    "1. `def embedding_layer(self)`\n",
    "2. `def backbone(self)`\n",
    "3. `def head(self)`\n",
    "\n",
    "There is another method `def _build_network(self)` which also needs to be defined mandatorily. We can use this method to define the embedding, backbone, head, and whatever other layers or components we need to use in the model.\n",
    "\n",
    "For standard Feed Forward layers as the head, we can also use a handly method in `BaseModel` called `_get_head_from_config` which will use the `head` and `head_config` you have set in the `ModelConfig` to initialize the right head automatically for you.\n",
    "\n",
    "An example of using it:\n",
    "```python\n",
    "self._head = self._get_head_from_config()\n",
    "```\n",
    "\n",
    "For standard flows, the `forward` method that is already defined would be enough.\n",
    "```python\n",
    "def forward(self, x: Dict):\n",
    "    x = self.embed_input(x) # Embeds the input dictionary and returns a tensor\n",
    "    x = self.compute_backbone(x) # Takes the tensor input and does representation learning\n",
    "    return self.compute_head(x) # transforms the backbone output to desired output, applies target range if necessary, and packs the output in the desired format.\n",
    "```\n",
    "What this allows us to do is to define any standard PyTorch model and use it as the **backbone**, and then use the rest of the PyTorch Tabular machinery to train, evaluate and log the model.\n",
    "\n",
    "While this is the bare minimum, you can redefine or use any of the Pytorch Lightning standard methods to tweak your model and training to your liking. The real beauty of this setup is that how much customization you need to do is really upto you. For standard workflows, you can change the minimum. And for highly unsual models, you can re-define any of the method in `BaseModel` (as long as the interfaces remain unchanged).\n",
    "\n",
    "\n",
    "In addition to the model, you will also need to define a config. Configs are python dataclasses and should inherit `ModelConfig` and will have all the parameters of the ModelConfig. by default. Any additional parameter should be defined in the dataclass. \n",
    "\n",
    "\n",
    "**Key things to note:**\n",
    "\n",
    "1. All the different parameters in the different configs(like TrainerConfig, OptimizerConfig, etc) are all available in `config` before calling `super()` and in `self.hparams` after.\n",
    "2. the input batch at the `forward` method is a dictionary with keys `continuous` and `categorical`\n",
    "3. In the `_build_network` method, save every component that you want access in the `forward` to `self`"
   ]
  },
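  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "To make the input format from the notes above concrete, here is a minimal, self-contained sketch of how a batch dictionary with `categorical` and `continuous` keys could be turned into a single tensor. `ToyEmbedding1d` and its arguments are hypothetical names invented for this illustration; it is not the library's `Embedding1dLayer`, only an approximation of the idea in plain PyTorch.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class ToyEmbedding1d(nn.Module):\n",
    "    # Hypothetical stand-in for an embedding layer (illustration only)\n",
    "    def __init__(self, embedding_dims):\n",
    "        super().__init__()\n",
    "        # embedding_dims: list of (cardinality, embedding_size), one per categorical column\n",
    "        self.embeddings = nn.ModuleList(\n",
    "            [nn.Embedding(card, size) for card, size in embedding_dims]\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        cont = x[\"continuous\"]  # float tensor, shape (batch, n_continuous)\n",
    "        cat = x[\"categorical\"]  # integer tensor, shape (batch, n_categorical)\n",
    "        # Embed each categorical column separately, then append the continuous features\n",
    "        embedded = [emb(cat[:, i]) for i, emb in enumerate(self.embeddings)]\n",
    "        return torch.cat(embedded + [cont], dim=1)\n",
    "\n",
    "\n",
    "# A made-up batch: 8 rows, 16 continuous columns, 4 categorical columns with 5 levels each\n",
    "batch = {\n",
    "    \"continuous\": torch.randn(8, 16),\n",
    "    \"categorical\": torch.randint(0, 5, (8, 4)),\n",
    "}\n",
    "layer = ToyEmbedding1d(embedding_dims=[(5, 3)] * 4)\n",
    "out = layer(batch)\n",
    "print(out.shape)  # torch.Size([8, 28]), i.e. 4 * 3 embedded + 16 continuous\n",
    "```\n",
    "\n",
    "Each categorical column gets its own embedding table, and the embedded columns are concatenated with the continuous features, which is why the width of the result is the sum of the embedding sizes plus the number of continuous columns."
   ]
  },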
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "# Define a Config class to hold the hyperparameters of the model\n",
    "# We need to inherit ModelConfig so that default parameters like learning rate, head, head_config, etc.\n",
    "# are present in the config\n",
    "@dataclass\n",
    "class MyAwesomeModelConfig(ModelConfig):\n",
    "    use_batch_norm: bool = True\n",
    "\n",
    "\n",
    "# Define the core model as a pure PyTorch model\n",
    "# The forward method takes in a tensor and outputs a tensor\n",
    "\n",
    "\n",
    "class MyAwesomeRegressionModel(nn.Module):\n",
    "    def __init__(self, hparams: DictConfig, **kwargs):\n",
    "        super().__init__()\n",
    "        self.hparams = (\n",
    "            hparams  # Save the config to be accessed in the model elsewhere\n",
    "        )\n",
    "        self._build_network()\n",
    "\n",
    "    def _build_network(self):\n",
    "        # Continuous and Categorical Dimensions are precalculated and stored in the config (InferredConfig)\n",
    "        inp_dim = self.hparams.embedded_cat_dim + self.hparams.continuous_dim\n",
    "        self.linear_layer_1 = nn.Linear(inp_dim, 200)\n",
    "        self.linear_layer_2 = nn.Linear(inp_dim + 200, 70)\n",
    "        self.linear_layer_3 = nn.Linear(inp_dim + 70, 70)\n",
    "        self.input_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim)\n",
    "        if self.hparams.use_batch_norm:\n",
    "            self.batch_norm_2 = nn.BatchNorm1d(inp_dim + 200)\n",
    "            self.batch_norm_3 = nn.BatchNorm1d(inp_dim + 70)\n",
    "        self.dropout = nn.Dropout(0.3)\n",
    "        self.output_dim = 70\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        inp = x\n",
    "        x = F.relu(self.linear_layer_1(x))\n",
    "        x = self.dropout(x)\n",
    "        x = torch.cat([x, inp], 1)\n",
    "        if self.hparams.use_batch_norm:\n",
    "            x = self.batch_norm_1(x)\n",
    "        x = F.relu(self.linear_layer_2(x))\n",
    "        x = self.dropout(x)\n",
    "        x = torch.cat([x, inp], 1)\n",
    "        if self.hparams.use_batch_norm:\n",
    "            x = self.batch_norm_3(x)\n",
    "        x = self.linear_layer_3(x)\n",
    "        return x\n",
    "\n",
    "    # It is also good practice to bundle the embedding scheme along with the core model\n",
    "    # This is so that the model encapsulates it's requirement of how to embed the input\n",
    "    # dictionary of categorical and continuous tensors and how to combine them into a single tensor.\n",
    "    # We can either use one of the predefined embedding layers in PyTorch Tabular\n",
    "    # Or define a custom embedding layer as well. Check Embedding1dLayer implementation\n",
    "    # on how to implement an embedding layer\n",
    "    def _build_embedding_layer(self):\n",
    "        return Embedding1dLayer(\n",
    "            continuous_dim=self.hparams.continuous_dim,\n",
    "            categorical_embedding_dims=self.hparams.embedding_dims,\n",
    "            embedding_dropout=self.hparams.embedding_dropout,\n",
    "            batch_norm_continuous_input=self.hparams.batch_norm_continuous_input,\n",
    "        )\n",
    "\n",
    "\n",
    "# Define the PyTorch Tabular model by inheriting BaseModel\n",
    "\n",
    "\n",
    "class MyAwesomeRegressionPTModel(BaseModel):\n",
    "    def __init__(self, config: DictConfig, **kwargs):\n",
    "        # Save any attribute that you need in _build_network before calling super()\n",
    "        # After the super() call, the config will be saved as `hparams`\n",
    "        super().__init__(config, **kwargs)\n",
    "\n",
    "    def _build_network(self):\n",
    "        # Backbone - Initialize the PyTorch model we defined earlier\n",
    "        self._backbone = MyAwesomeRegressionModel(self.hparams)\n",
    "        # Initializing Embedding Layer using the method we defined in the backbone\n",
    "        self._embedding_layer = self._backbone._build_embedding_layer()\n",
    "        # Head - we can use the helper method to initialize standard linear head\n",
    "        self.head = self._get_head_from_config()\n",
    "\n",
    "    # Now define the property methods which the BaseModel requires you to override\n",
    "    @property\n",
    "    def backbone(self):\n",
    "        return self._backbone\n",
    "\n",
    "    @property\n",
    "    def embedding_layer(self):\n",
    "        return self._embedding_layer\n",
    "\n",
    "    @property\n",
    "    def head(self):\n",
    "        return self._head\n",
    "\n",
    "    # For more customizations, we can override forward, compute_backbone, compute_head etc."
   ]
  },
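  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "The layer sizes in the backbone follow from its skip connections: after each hidden layer the original input is concatenated back on, so the next layer's input width grows by `inp_dim`. The sketch below traces those shapes with standalone layers in plain PyTorch. The value `inp_dim = 28` and the batch size of 8 are arbitrary choices for illustration, not values read from the config.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "inp_dim = 28  # arbitrary illustrative width (assumption)\n",
    "linear_1 = nn.Linear(inp_dim, 200)\n",
    "linear_2 = nn.Linear(inp_dim + 200, 70)\n",
    "linear_3 = nn.Linear(inp_dim + 70, 70)\n",
    "\n",
    "inp = torch.randn(8, inp_dim)\n",
    "x = F.relu(linear_1(inp))  # (8, 200)\n",
    "x = torch.cat([x, inp], dim=1)  # (8, 228): hidden output plus the raw input\n",
    "x = F.relu(linear_2(x))  # (8, 70)\n",
    "x = torch.cat([x, inp], dim=1)  # (8, 98)\n",
    "x = linear_3(x)  # (8, 70), matching output_dim = 70\n",
    "print(x.shape)  # torch.Size([8, 70])\n",
    "```\n",
    "\n",
    "Any batch norm applied after a concatenation has to be sized for the concatenated width (`inp_dim + 200` or `inp_dim + 70`), not for the hidden layer alone."
   ]
  },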
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Define the Configs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "There is one deviation from the normal when we create a TabularModel object with the configs. Earlier the model was inferred from the config and initialized autmatically. But here, we have to use the `model_callable` parameter of the TabularModel and pass in the model class(not the initialized object)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">26</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">16:43:07</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">659</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">134</span><span style=\"font-weight: bold\">}</span> - INFO - Experiment Tracking is turned off           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m26\u001b[0m \u001b[1;92m16:43:07\u001b[0m,\u001b[1;36m659\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m134\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "data_config = DataConfig(\n",
    "    target=[\n",
    "        \"target\"\n",
    "    ],  # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names,\n",
    ")\n",
    "trainer_config = TrainerConfig(\n",
    "    auto_lr_find=True,  # Runs the LRFinder to automatically derive a learning rate\n",
    "    batch_size=1024,\n",
    "    max_epochs=100,\n",
    "    accelerator=\"auto\",  # can be 'cpu','gpu', 'tpu', or 'ipu'\n",
    "    devices=-1,  # -1 means use all available\n",
    ")\n",
    "optimizer_config = OptimizerConfig()\n",
    "\n",
    "head_config = LinearHeadConfig(\n",
    "    layers=\"32\",  # Number of nodes in each layer\n",
    "    activation=\"ReLU\",  # Activation between each layers\n",
    "    dropout=0.1,\n",
    "    initialization=\"kaiming\",\n",
    ").__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)\n",
    "\n",
    "model_config = MyAwesomeModelConfig(\n",
    "    task=\"regression\",\n",
    "    use_batch_norm=False,\n",
    "    head=\"LinearHead\",  # Linear Head\n",
    "    head_config=head_config,  # Linear Head Config\n",
    "    learning_rate=1e-3,\n",
    ")\n",
    "\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    "    model_callable=MyAwesomeRegressionPTModel,  # When using custom model, we need to use the model_callable parameter\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Training the Model \n",
    "\n",
    "The rest of the process is business-as-usual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "Collapsed": "false",
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">26</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">16:43:10</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">447</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">506</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the DataLoaders                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m26\u001b[0m \u001b[1;92m16:43:10\u001b[0m,\u001b[1;36m447\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m506\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">26</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">16:43:10</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">455</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_datamodul<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e:431</span><span style=\"font-weight: bold\">}</span> - INFO - Setting up the datamodule for          \n",
       "regression task                                                                                                    \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m26\u001b[0m \u001b[1;92m16:43:10\u001b[0m,\u001b[1;36m455\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:431\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for          \n",
       "regression task                                                                                                    \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">26</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">16:43:10</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">490</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">556</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Model: Model                  \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m26\u001b[0m \u001b[1;92m16:43:10\u001b[0m,\u001b[1;36m490\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m556\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: Model                  \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">26</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">16:43:10</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">541</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">322</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Trainer                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m26\u001b[0m \u001b[1;92m16:43:10\u001b[0m,\u001b[1;36m541\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m322\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">26</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">16:43:10</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">722</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">612</span><span style=\"font-weight: bold\">}</span> - INFO - Auto LR Find Started                        \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m26\u001b[0m \u001b[1;92m16:43:10\u001b[0m,\u001b[1;36m722\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m612\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started                        \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1307290dd4cb40328ae118c0a6d430bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LR finder stopped early after 99 steps due to diverging loss.\n",
      "Learning rate set to 0.008317637711026709\n",
      "Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_c0a3e124-98c0-4ff0-80c4-0287c2f67b80.ckpt\n",
      "Restored all states from the checkpoint at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_c0a3e124-98c0-4ff0-80c4-0287c2f67b80.ckpt\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">26</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">16:43:14</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">083</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">625</span><span style=\"font-weight: bold\">}</span> - INFO - Suggested LR: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.008317637711026709</span>. For plot\n",
       "and detailed analysis, use `find_learning_rate` method.                                                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m26\u001b[0m \u001b[1;92m16:43:14\u001b[0m,\u001b[1;36m083\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m625\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m0.008317637711026709\u001b[0m. For plot\n",
       "and detailed analysis, use `find_learning_rate` method.                                                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">26</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">16:43:14</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">085</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">634</span><span style=\"font-weight: bold\">}</span> - INFO - Training Started                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m26\u001b[0m \u001b[1;92m16:43:14\u001b[0m,\u001b[1;36m085\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m634\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                     </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ MyAwesomeRegressionModel │ 28.8 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding1dLayer         │     92 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ head             │ LinearHead               │  2.3 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ MSELoss                  │      0 │\n",
       "└───┴──────────────────┴──────────────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                    \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ MyAwesomeRegressionModel │ 28.8 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer         │     92 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head             │ LinearHead               │  2.3 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ MSELoss                  │      0 │\n",
       "└───┴──────────────────┴──────────────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 31.2 K                                                                                           \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 31.2 K                                                                                               \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 0                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 31.2 K                                                                                           \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 31.2 K                                                                                               \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">26</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">16:43:16</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">563</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">645</span><span style=\"font-weight: bold\">}</span> - INFO - Training the model completed                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m26\u001b[0m \u001b[1;92m16:43:16\u001b[0m,\u001b[1;36m563\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m645\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">26</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">16:43:16</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">564</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1491</span><span style=\"font-weight: bold\">}</span> - INFO - Loading the best model                     \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m26\u001b[0m \u001b[1;92m16:43:16\u001b[0m,\u001b[1;36m564\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1491\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model                     \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7f08ebab7e90>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train on the training split and monitor performance on the validation split\n",
    "tabular_model.fit(train=train, validation=val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    509.21295166015625     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">  test_mean_squared_error  </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    509.21295166015625     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   509.21295166015625    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m test_mean_squared_error \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   509.21295166015625    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compute the loss and configured metrics on the held-out test split\n",
    "result = tabular_model.evaluate(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>target_prediction</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>6252</th>\n",
       "      <td>-111.886528</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4684</th>\n",
       "      <td>-169.181137</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1731</th>\n",
       "      <td>272.100281</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4742</th>\n",
       "      <td>12.167077</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4521</th>\n",
       "      <td>56.998451</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      target_prediction\n",
       "6252        -111.886528\n",
       "4684        -169.181137\n",
       "1731         272.100281\n",
       "4742          12.167077\n",
       "4521          56.998451"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predict on the test split; predictions are returned in a `target_prediction` column\n",
    "pred_df = tabular_model.predict(test)\n",
    "pred_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Holdout MSE: 509.21296869752325 | Holdout MAE: 17.057807657741225\n"
     ]
    }
   ],
   "source": [
    "# Score the predictions against the true target with scikit-learn's MSE and MAE\n",
    "print_metrics(\n",
    "    [(mean_squared_error, \"MSE\", {}), (mean_absolute_error, \"MAE\", {})],\n",
    "    test[\"target\"],\n",
    "    pred_df[\"target_prediction\"],\n",
    "    tag=\"Holdout\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tabular_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "ad8d5d2789703c7b1c2f7bfaada1cbd3aa0ac53e2e4e1cae5da195f5520da229"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
